%% consumption_FK.m
%
% Computes cumulative consumption and the expected future consumption rate
% over horizons up to T using the Feynman-Kac formula.
% Note: assumes the spot interest rate stays constant (steady state).
%--------------------------------------------------------------------------

% Preliminaries
% The generators below are hard-coded for spot-rate indices 2 and 3 only.
assert(ismember(rb_spot_ind_stationary, [2 3]), 'Fix hard-coded interest rates');
T = 1;            % horizon of the Feynman-Kac recursion
Delta_FK = 1/24;  % implicit time step of the recursion


%--------------------------------------------------------------------------
%% Get relevant transition matrices
% Block-diagonal generators over the Nr mortgage-rate states: diagonal block
% rm_ind is AStoreSS{k,rm_ind}, with k = 3 for A3 and k = 2 for A2
% (presumably the spot-rate index; see the assertion above).
% blkdiag builds the matrix in one call instead of regrowing it in a loop;
% sparse() keeps the result sparse even if the stored blocks are full.
A3 = sparse(blkdiag(AStoreSS{3,1:Nr}));
A2 = sparse(blkdiag(AStoreSS{2,1:Nr}));

%Account for retirement (see solveKF for details)
% DD2/DD3 are generator contributions at rate la_BY. Every state (any mortgage
% rate, b, a) keeps its income index nz and moves into mortgage-rate block
% rb (2 or 3), a index Na, spread with equal weight la_BY/n over the n b-grid
% points bMinInd..bMaxInd. The diagonal is -la_BY, so each row sums to zero.
% Grid ordering (b fastest, then a, then z, then rate) matches the
% reshape(..., Nb,Na,Nz,Nr) used later, so M = Nb*Na*Nz.
if deathFlag == 1
    bMinInd = discretize(0,b);
    bMaxInd = discretize(zstates(1)/2,b);
    % discretize returns NaN outside the grid; fail with a clear message
    assert(~isnan(bMinInd) && ~isnan(bMaxInd), ...
        'Retirement b-range [0, zstates(1)/2] lies outside the b grid');
    wt = la_BY/(bMaxInd-bMinInd+1);

    nbna = Nb*Na;
    % Destination positions inside one income block: Nb*(Na-1) + b
    destInBlock = nbna - Nb + (bMinInd:bMaxInd);
    fromRows = (1:Mall)';
    % Zero-based start of each row's income block within its rate block
    zOffset = nbna*floor(mod(fromRows-1, M)/nbna);
    rowIdx = repmat(fromRows, 1, numel(destInBlock));
    colInRate = zOffset + destInBlock;   % Mall x nDest (implicit expansion)

    % One sparse() call per destination rate block rb, instead of Nr*Nz
    % incremental Mall x Mall sparse additions.
    retireGen = @(rb) sparse(rowIdx(:), M*(rb-1) + colInRate(:), wt, Mall, Mall) ...
        - la_BY*speye(Mall,Mall);
    DD2 = retireGen(2);
    DD3 = retireGen(3);

    clear nbna destInBlock fromRows zOffset rowIdx colInRate retireGen
else
    DD2 = sparse(Mall,Mall);
    DD3 = sparse(Mall,Mall);
end

% Add forced-refinancing and retirement flows to the steady-state generators
A3 = A3 + forcedRefiStore{3} + DD3;
A2 = A2 + forcedRefiStore{2} + DD2;

% Rate at which the refinancing intervention is applied
if slow_refi ~= 1
    la_slowrefi_tmp = 1000; %helps numerically to have non-infinite transition
else
    la_slowrefi_tmp = la_slowrefi;
    assert(faster_prepay == 0, 'Code here assumes same refi/prepay rate');
end

% Convert an intervention (jump) matrix P into the generator of a Poisson jump
% at rate la_slowrefi_tmp. Off-diagonal entries are la*P(i,j). Each diagonal
% entry is minus its row's off-diagonal sum, so every row sums to zero
% exactly. Hard-coding the diagonal to -la would only be correct when a row's
% off-diagonal weights sum to exactly 1. It would leak mass whenever a row
% also puts weight on its own node; this form gives la*(P - I) in that case.
diagOf = @(v) spdiags(full(v), 0, Mall, Mall);
offDiagRate = @(P) la_slowrefi_tmp*(P - diagOf(spdiags(P,0)));
toGenerator = @(G) G - diagOf(sum(G,2));

Ctmp = toGenerator(offDiagRate(interventionStore{3}));
AFK3 = A3 + Ctmp;

Ctmp = toGenerator(offDiagRate(interventionStore{2}));
AFK2 = A2 + Ctmp;

clear diagOf offDiagRate toGenerator

% Sanity check: generator rows must sum to zero
if max(abs(sum(AFK3, 2))) > 1e-10
    warning('Some rows of A do not sum to 0')
end
if max(abs(sum(AFK2, 2))) > 1e-10
    warning('Some rows of A do not sum to 0')
end


%--------------------------------------------------------------------------
%% Calculate Cumulative Consumption and Expected Future Consumption Rate
% Both quantities are implicit (backward Euler) Feynman-Kac recursions that
% share the matrix B = (1/Delta_FK)*I - AFK for a given spot rate:
%   cumulative consumption:  C  <- B \ (c + (1/Delta_FK)*C),  starting from C = 0
%   expected consumption:    Ec <- B \ ((1/Delta_FK)*Ec),     starting from Ec = c
% B does not change across time steps. It is therefore LU-factorised once per
% spot rate and the factors are reused, instead of re-factorising B inside
% every backslash.
nSteps = numel(Delta_FK:Delta_FK:T);

% Snapshot horizons. They are matched on integer step counts rather than by
% exact comparison of floating-point times. A horizon that is not a whole
% number of steps is marked NaN, so it is never stored.
snapTimes = [1/4, 1/2, 1, 2];
snapStep = round(snapTimes/Delta_FK);
snapStep(abs(snapTimes/Delta_FK - snapStep) > 1e-8) = NaN;

% - - - - - - - - - - - - - - spot-rate index 3 - - - - - - - - - - - - - - -
cCols = cellfun(@(x) x(:), cStore(3,1:Nr), 'UniformOutput', false);
c_flow3 = vertcat(cCols{:});

% [L,U,P,Q,R] = lu(S) satisfies P*(R\S)*Q = L*U, hence S\y = Q*(U\(L\(P*(R\y))))
[L_FK, U_FK, P_FK, Q_FK, R_FK] = lu(sparse((1/Delta_FK)*speye(Mall,Mall) - AFK3));
C_FK3 = zeros(Mall,1);
Ec_FK3 = c_flow3; %terminal condition
for ind = 1:nSteps
    % Solve both recursions at once (two right-hand-side columns)
    rhs = [c_flow3 + (1/Delta_FK)*C_FK3, (1/Delta_FK)*Ec_FK3];
    sol = Q_FK*(U_FK\(L_FK\(P_FK*(R_FK\rhs))));
    C_FK3 = sol(:,1);
    Ec_FK3 = sol(:,2);
    if ind == snapStep(1)
        C_FK3_0pt25 = reshape(C_FK3, Nb,Na,Nz,Nr);
        Ec_FK3_0pt25 = reshape(Ec_FK3, Nb,Na,Nz,Nr);
    end
    if ind == snapStep(2)
        C_FK3_0pt5 = reshape(C_FK3, Nb,Na,Nz,Nr);
        Ec_FK3_0pt5 = reshape(Ec_FK3, Nb,Na,Nz,Nr);
    end
    if ind == snapStep(3)
        C_FK3_1 = reshape(C_FK3, Nb,Na,Nz,Nr);
        Ec_FK3_1 = reshape(Ec_FK3, Nb,Na,Nz,Nr);
    end
    if ind == snapStep(4) % only reached when T >= 2
        C_FK3_2 = reshape(C_FK3, Nb,Na,Nz,Nr);
        Ec_FK3_2 = reshape(Ec_FK3, Nb,Na,Nz,Nr);
    end
end

% - - - - - - - - - - - - - - spot-rate index 2 - - - - - - - - - - - - - - -
cCols = cellfun(@(x) x(:), cStore(2,1:Nr), 'UniformOutput', false);
c_flow2 = vertcat(cCols{:});

[L_FK, U_FK, P_FK, Q_FK, R_FK] = lu(sparse((1/Delta_FK)*speye(Mall,Mall) - AFK2));
C_FK2 = zeros(Mall,1);
Ec_FK2 = c_flow2; %terminal condition
for ind = 1:nSteps
    rhs = [c_flow2 + (1/Delta_FK)*C_FK2, (1/Delta_FK)*Ec_FK2];
    sol = Q_FK*(U_FK\(L_FK\(P_FK*(R_FK\rhs))));
    C_FK2 = sol(:,1);
    Ec_FK2 = sol(:,2);
    if ind == snapStep(1)
        C_FK2_0pt25 = reshape(C_FK2, Nb,Na,Nz,Nr);
        Ec_FK2_0pt25 = reshape(Ec_FK2, Nb,Na,Nz,Nr);
    end
    if ind == snapStep(2)
        C_FK2_0pt5 = reshape(C_FK2, Nb,Na,Nz,Nr);
        Ec_FK2_0pt5 = reshape(Ec_FK2, Nb,Na,Nz,Nr);
    end
    if ind == snapStep(3)
        C_FK2_1 = reshape(C_FK2, Nb,Na,Nz,Nr);
        Ec_FK2_1 = reshape(Ec_FK2, Nb,Na,Nz,Nr);
    end
    if ind == snapStep(4) % only reached when T >= 2
        C_FK2_2 = reshape(C_FK2, Nb,Na,Nz,Nr);
        Ec_FK2_2 = reshape(Ec_FK2, Nb,Na,Nz,Nr);
    end
end

% Release the LU factors and other temporaries introduced above
clear L_FK U_FK P_FK Q_FK R_FK rhs sol cCols nSteps snapTimes snapStep



%--------------------------------------------------------------------------
%% Clean Up
% Remove intermediate variables built above: the retirement/intervention
% matrices and refi rate, then the flow vectors and unreshaped FK solutions.
clear('DDrel', 'DD2', 'DD3', 'Ctmp', 'la_slowrefi_tmp');
clear('c_flow2', 'c_flow3', 'C_FK2', 'C_FK3', 'Ec_FK2', 'Ec_FK3');


